{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.163323Z",
     "start_time": "2019-01-29T22:48:53.680455Z"
    }
   },
   "outputs": [],
   "source": [
    "%load ../../utils/djl-imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T19:41:18.204791Z",
     "start_time": "2019-01-29T19:41:18.201091Z"
    }
   },
   "source": [
    "## 1. Multinomial Sampling\n",
    "\n",
    "Implement a sampler for a discrete distribution from scratch, mimicking the function `manager.randomMultinomial()`. Its arguments should be a vector of probabilities $p$ and the shape of the output. You can assume that the probabilities are normalized, i.e. that they sum up to $1$. Make the call signature as follows:\n",
    "\n",
    "```\n",
    "samples = sampler(probs, shape) \n",
    "\n",
    "probs   : A float array of size n of nonnegative numbers summing up to 1\n",
    "shape   : Shape object declaring dimensions for the output\n",
    "samples : Samples from probs with shape matching shape\n",
    "```\n",
    "\n",
    "Hints:\n",
    "\n",
    "1. Use `manager.randomUniform()` to get a sample from $U[0,1]$.\n",
    "1. You can simplify things for `probs` by computing the cumulative sum over `probs`."
   ]
  },
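  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the second hint concrete, here is a small worked example. The draw $u = 0.6$ is an arbitrary value chosen for illustration, and the symbols $c_k$ and $u$ are introduced here rather than taken from the assignment. For $p = (0.2, 0.3, 0.5)$, the cumulative sums split $[0, 1)$ into one interval per outcome, and a uniform draw lands in the $k$-th interval with probability $p_k$:\n",
    "\n",
    "$$\n",
    "c = (c_1, c_2, c_3) = (0.2,\\ 0.5,\\ 1.0), \\qquad\n",
    "\\Pr(c_{k-1} \\le u < c_k) = c_k - c_{k-1} = p_k \\quad (c_0 = 0)\n",
    "$$\n",
    "\n",
    "With $u = 0.6$ we have $c_2 = 0.5 \\le 0.6 < 1.0 = c_3$, so this draw would select the third outcome."
   ]
  },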
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.185124Z",
     "start_time": "2019-01-29T22:48:56.165645Z"
    }
   },
   "outputs": [],
   "source": [
    "NDArray sampler(float[] probs, Shape shape) {\n",
    "    // Add your code here\n",
    "    NDManager manager = NDManager.newBaseManager();\n",
    "    return manager.zeros(shape);\n",
    "}\n",
    "\n",
    "// A simple test\n",
    "sampler(new float[]{0.2f, 0.3f, 0.5f}, new Shape(2,3));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Central Limit Theorem\n",
    "\n",
    "Let's explore the Central Limit Theorem when applied to text processing.\n",
    "\n",
    "* Download [https://www.gutenberg.org/ebooks/84](https://www.gutenberg.org/files/84/84-0.txt) from Project Gutenberg.\n",
    "* Remove punctuation, convert all text to lowercase, and split it into individual tokens (words).\n",
    "* For the words `a`, `and`, `the`, `i`, `is` compute their respective counts as the book progresses, i.e. (with $\\mathbf{1}\\{\\cdot\\}$ equal to $1$ when the condition holds and $0$ otherwise)\n",
    "    $$n_\\mathrm{the}[i] = \\sum_{j = 1}^i \\mathbf{1}\\{w_j = \\mathrm{the}\\}$$\n",
    "* Plot the proportions $n_\\mathrm{word}[i] / i$ over the document in one plot.\n",
    "* Find an envelope of the shape $O(1/\\sqrt{i})$ for each of these five words. (Hint: check the last page of the [sampling notebook](http://courses.d2l.ai/berkeley-stat-157/slides/1_24/sampling.pdf).)\n",
    "* Why can we **not** apply the Central Limit Theorem directly? \n",
    "* How would we have to change the text for it to apply? \n",
    "* Why does it still work quite well?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.200892Z",
     "start_time": "2019-01-29T22:48:56.188145Z"
    }
   },
   "outputs": [],
   "source": [
    "// Read the book as a list of whitespace-separated tokens\n",
    "URL url = new URL(\"https://www.gutenberg.org/files/84/84-0.txt\");\n",
    "Scanner s = new Scanner(url.openStream());\n",
    "ArrayList<String> book = new ArrayList<>();\n",
    "while (s.hasNext()) {\n",
    "    book.add(s.next());\n",
    "}\n",
    "for (int i = 0; i < 10; i++) {\n",
    "    System.out.println(book.get(i));\n",
    "}\n",
    "\n",
    "// Add your code here"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Denominator-layout notation\n",
    "\n",
    "We used the numerator-layout notation for matrix calculus in class; now let's examine the denominator-layout notation.\n",
    "\n",
    "Given $x, y\\in\\mathbb R$, $\\mathbf x\\in\\mathbb R^n$ and $\\mathbf y \\in \\mathbb R^m$, we have\n",
    "\n",
    "$$\n",
    "\\frac{\\partial y}{\\partial \\mathbf{x}}=\\begin{bmatrix}\n",
    "\\frac{\\partial y}{\\partial x_1}\\\\\n",
    "\\frac{\\partial y}{\\partial x_2}\\\\\n",
    "\\vdots\\\\\n",
    "\\frac{\\partial y}{\\partial x_n}\n",
    "\\end{bmatrix},\\quad \n",
    "\\frac{\\partial \\mathbf y}{\\partial {x}}=\\begin{bmatrix}\n",
    "\\frac{\\partial y_1}{\\partial x}, \n",
    "\\frac{\\partial y_2}{\\partial x}, \n",
    "\\ldots,\n",
    "\\frac{\\partial y_m}{\\partial x}\n",
    "\\end{bmatrix}\n",
    "$$\n",
    "\n",
    "and \n",
    "\n",
    "$$\n",
    "\\frac{\\partial \\mathbf y}{\\partial \\mathbf{x}}\n",
    "=\\begin{bmatrix}\n",
    "\\frac{\\partial \\mathbf y}{\\partial {x_1}}\\\\\n",
    "\\frac{\\partial \\mathbf y}{\\partial {x_2}}\\\\\n",
    "\\vdots\\\\\n",
    "\\frac{\\partial \\mathbf y}{\\partial {x_n}}\\\\\n",
    "\\end{bmatrix}\n",
    "=\\begin{bmatrix}\n",
    "\\frac{\\partial y_1}{\\partial x_1}, \n",
    "\\frac{\\partial y_2}{\\partial x_1},\n",
    "\\ldots,\n",
    "\\frac{\\partial y_m}{\\partial x_1}\n",
    "\\\\ \n",
    "\\frac{\\partial y_1}{\\partial x_2},\n",
    "\\frac{\\partial y_2}{\\partial x_2},\n",
    "\\ldots,\n",
    "\\frac{\\partial y_m}{\\partial x_2}\\\\ \n",
    "\\vdots\\\\\n",
    "\\frac{\\partial y_1}{\\partial x_n},\n",
    "\\frac{\\partial y_2}{\\partial x_n},\n",
    "\\ldots,\n",
    "\\frac{\\partial y_m}{\\partial x_n}\n",
    "\\end{bmatrix}\n",
    "$$\n",
    "\n",
    "Questions: \n",
    "\n",
    "1. Assume $\\mathbf  y = f(\\mathbf u)$ and $\\mathbf u = g(\\mathbf x)$; write down the chain rule for $\\frac {\\partial\\mathbf  y}{\\partial\\mathbf x}$.\n",
    "2. Given $\\mathbf X \\in \\mathbb R^{m\\times n},\\ \\mathbf w \\in \\mathbb R^n, \\ \\mathbf y \\in \\mathbb R^m$, assume $z = \\| \\mathbf X \\mathbf w - \\mathbf y\\|^2$, compute $\\frac{\\partial z}{\\partial\\mathbf w}$."
   ]
  },
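  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the layout defined above, consider two simple functions. The symbols $\\mathbf a \\in \\mathbb R^n$ and $\\mathbf A \\in \\mathbb R^{m\\times n}$ are introduced here purely for illustration. Reading off the entries $\\partial y_j / \\partial x_i$ gives:\n",
    "\n",
    "$$\n",
    "y = \\mathbf a^\\top \\mathbf x = \\sum_{i=1}^n a_i x_i\n",
    "\\quad\\Longrightarrow\\quad\n",
    "\\frac{\\partial y}{\\partial \\mathbf x} = \\mathbf a \\in \\mathbb R^{n}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\mathbf y = \\mathbf A \\mathbf x,\\quad y_j = \\sum_{i=1}^n A_{ji} x_i\n",
    "\\quad\\Longrightarrow\\quad\n",
    "\\left[\\frac{\\partial \\mathbf y}{\\partial \\mathbf x}\\right]_{ij} = \\frac{\\partial y_j}{\\partial x_i} = A_{ji},\n",
    "\\quad\\text{i.e.}\\quad\n",
    "\\frac{\\partial \\mathbf y}{\\partial \\mathbf x} = \\mathbf A^\\top \\in \\mathbb R^{n\\times m}\n",
    "$$\n",
    "\n",
    "In this layout the derivative has as many rows as $\\mathbf x$ has entries, which is the transpose of what the numerator layout would give."
   ]
  },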
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Numerical Precision\n",
    "\n",
    "Given scalars `x` and `y`, implement the `logExp()` function below so that it returns\n",
    "$$-\\log\\left(\\frac{e^x}{e^x+e^y}\\right).$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.206890Z",
     "start_time": "2019-01-29T22:48:56.202996Z"
    }
   },
   "outputs": [],
   "source": [
    "import java.util.function.BinaryOperator; \n",
    "\n",
    "// Here we wrap the function in a class\n",
    "// so that we can pass its reference to a function\n",
    "static class Function {\n",
    "    static NDArray logExp(NDArray x, NDArray y) {\n",
    "        // Add your solution here\n",
    "        NDManager manager = NDManager.newBaseManager();\n",
    "        return manager.zeros(new Shape(1));\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test your code with normal inputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.215579Z",
     "start_time": "2019-01-29T22:48:56.209659Z"
    }
   },
   "outputs": [],
   "source": [
    "var x = manager.create(new float[]{2});\n",
    "var y = manager.create(new float[]{3});\n",
    "var z = Function.logExp(x, y);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now implement a function to compute $\\partial z/\\partial x$ and $\\partial z/\\partial y$ with a `GradientCollector`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.223303Z",
     "start_time": "2019-01-29T22:48:56.218056Z"
    }
   },
   "outputs": [],
   "source": [
    "void grad(BinaryOperator<NDArray> forwardFunction, \n",
    "          NDArray x, NDArray y) {\n",
    "    // Add your code here\n",
    "    // Note: This will throw an error \n",
    "    // if you try to run this in its present form\n",
    "    // since the gradient for each NDArray \n",
    "    // has not yet been calculated.\n",
    "    System.out.printf(\"Gradient of x = \");\n",
    "    System.out.println(x.getGradient());\n",
    "    System.out.printf(\"Gradient of y = \");\n",
    "    System.out.println(y.getGradient());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test your code; it should print the gradients of `x` and `y`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.267165Z",
     "start_time": "2019-01-29T22:48:56.227035Z"
    }
   },
   "outputs": [],
   "source": [
    "x.setRequiresGradient(true);\n",
    "y.setRequiresGradient(true);\n",
    "grad(Function::logExp, x, y);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's try some \"hard\" inputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.285842Z",
     "start_time": "2019-01-29T22:48:56.274079Z"
    }
   },
   "outputs": [],
   "source": [
    "x = manager.create(new float[]{50});\n",
    "y = manager.create(new float[]{100});\n",
    "\n",
    "x.setRequiresGradient(true);\n",
    "y.setRequiresGradient(true);\n",
    "\n",
    "grad(Function::logExp, x, y);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Does your code return correct results? If not, try to understand the reason. (Hint: evaluate `exp(100)`.) Now develop a new function `stableLogExp()` that is mathematically identical to `logExp()` but returns a numerically stable result."
   ]
  },
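  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a sense of the magnitudes involved, the figures below compare the exponentials of the two \"hard\" inputs with the largest finite single-precision value. This assumes the arrays are stored as 32-bit floats, which is what `manager.create(new float[]{...})` suggests; all numbers are rounded.\n",
    "\n",
    "$$\n",
    "e^{50} \\approx 5.2\\times 10^{21}, \\qquad\n",
    "e^{100} \\approx 2.7\\times 10^{43}, \\qquad\n",
    "\\text{largest finite float32} = (2 - 2^{-23})\\cdot 2^{127} \\approx 3.4\\times 10^{38}\n",
    "$$\n",
    "\n",
    "So $e^{100}$ lies about five orders of magnitude beyond what a 32-bit float can represent, while the exact value of the expression for $x = 50$, $y = 100$ is close to $50$."
   ]
  },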
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-01-29T22:48:56.305595Z",
     "start_time": "2019-01-29T22:48:56.293399Z"
    }
   },
   "outputs": [],
   "source": [
    "static class Function {\n",
    "    static NDArray stableLogExp(NDArray x, NDArray y) {\n",
    "        // Add your code here\n",
    "        return null;\n",
    "    }\n",
    "}\n",
    "\n",
    "grad(Function::stableLogExp, x, y);"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
